# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SAM logs tool for AWS Serverless MCP Server."""

from awslabs.aws_serverless_mcp_server.tools.common.base_tool import BaseTool
from awslabs.aws_serverless_mcp_server.utils.process import run_command
from loguru import logger
from mcp.server.fastmcp import Context, FastMCP
from pydantic import Field
from typing import Any, Dict, List, Optional


class SamLogsTool(BaseTool):
    """Tool to fetch logs from AWS SAM applications using the 'sam logs' command.

    On construction the tool registers ``handle_sam_logs`` with the given FastMCP
    server under the tool name ``sam_logs``.
    """

    def __init__(self, mcp: FastMCP, allow_sensitive_data_access):
        """Initialize the SAM logs tool and register it with the MCP server.

        Args:
            mcp: FastMCP server on which the ``sam_logs`` tool is registered.
            allow_sensitive_data_access: Access flag forwarded to ``BaseTool`` and also
                stored on this instance.
        """
        super().__init__(allow_sensitive_data_access=allow_sensitive_data_access)
        mcp.tool(name='sam_logs')(self.handle_sam_logs)
        self.allow_sensitive_data_access = allow_sensitive_data_access

    async def handle_sam_logs(
        self,
        ctx: Context,
        resource_name: Optional[str] = Field(
            default=None,
            description="""Name of the resource to fetch logs for. This is the logical ID of the function resource in the AWS CloudFormation/AWS SAM template.
                If you don't specify this option, AWS SAM fetches logs for all resources in the stack that you specify.
                You must specify stack_name when specifying resource_name.""",
        ),
        stack_name: Optional[str] = Field(
            default=None, description='Name of the CloudFormation stack'
        ),
        start_time: Optional[str] = Field(
            default=None,
            description='Fetch logs starting from this time (format: 5mins ago, tomorrow, or YYYY-MM-DD HH:MM:SS)',
        ),
        end_time: Optional[str] = Field(
            default=None,
            description='Fetch logs up until this time (format: 5mins ago, tomorrow, or YYYY-MM-DD HH:MM:SS)',
        ),
        region: Optional[str] = Field(
            default=None, description='AWS region to use (e.g., us-east-1)'
        ),
        profile: Optional[str] = Field(default=None, description='AWS profile to use'),
        cw_log_group: Optional[List[str]] = Field(
            default=None,
            description="""Use AWS CloudWatch to fetch logs. Includes logs from the CloudWatch Logs log groups that you specify.
                If you specify this option along with resource_name, AWS SAM includes logs from the specified log groups in addition to logs from the named resources.""",
        ),
        config_env: Optional[str] = Field(
            default=None,
            description='Environment name specifying default parameter values in the configuration file',
        ),
        config_file: Optional[str] = Field(
            default=None,
            description='Absolute path to configuration file containing default parameter values',
        ),
        save_params: bool = Field(
            default=False, description='Save parameters to the SAM configuration file'
        ),
    ) -> Dict[str, Any]:
        """Fetches CloudWatch logs that are generated by Lambda function and API GW resources in a SAM application.

        Requirements:
        - AWS SAM CLI MUST be installed and configured in your environment
        - Your SAM application MUST be deployed and receiving traffic

        After deploying your serverless application, you can use this tool to monitor it to provide insights on
        its operations and detect anomalies. Use this tool to help troubleshoot invocation failures and function code errors,
        and to find root causes. Lambda function logs contain application logs emitted by your code and platform level logs emitted by the Lambda service.

        Usage tips:
        - Use logs to debug out-of-memory errors. Platform logs indicate memory usage in the REPORT line. If memory usage is high compared to
        configured memory, out-of-memory could be causing invocation failures.
        - Use logs to debug timeout errors. Functions that have timed out contain a log line like 'Task timed out after 3.00 seconds'.

        Note: You MUST explicitly enable logging on API GW resources

        Returns:
            Dict: Log retrieval result. On success: ``success`` (True), ``message`` and
            ``output`` (the decoded stdout of ``sam logs``). On failure: ``success``
            (False), ``message`` and ``error`` (the string form of the exception).
        """
        self.checkToolAccess()

        try:
            # Build the command as an argument list (no shell involved).
            cmd = ['sam', 'logs']

            if resource_name:
                cmd.extend(['--name', resource_name])

            if config_env:
                cmd.extend(['--config-env', config_env])

            if config_file:
                cmd.extend(['--config-file', config_file])

            if cw_log_group:
                # --cw-log-group is a repeatable option: each log group needs its own
                # flag. Listing several values after a single flag would turn the
                # extras into unexpected positional arguments.
                for group in cw_log_group:
                    cmd.extend(['--cw-log-group', group])

            if start_time:
                cmd.extend(['--start-time', start_time])

            if end_time:
                cmd.extend(['--end-time', end_time])

            if save_params:
                cmd.append('--save-params')

            if stack_name:
                cmd.extend(['--stack-name', stack_name])

            if profile:
                cmd.extend(['--profile', profile])

            if region:
                cmd.extend(['--region', region])

            # Execute the command
            logger.info(f'Executing command: {" ".join(cmd)}')
            stdout, _stderr = await run_command(cmd)
            output = stdout.decode()
            message = (
                'Successfully fetched logs'
                if output
                else 'No logs found for the specified resource'
            )
            return {
                'success': True,
                'message': message,
                'output': output,
            }
        except Exception as e:
            # Prefer the process's stderr when the exception carries one. It may be
            # bytes, or None/empty if it was not captured; fall back to str(e) then.
            stderr = getattr(e, 'stderr', None)
            if isinstance(stderr, bytes):
                stderr = stderr.decode(errors='replace')
            error_message = stderr or str(e)
            logger.error(f'Error fetching logs for resource: {error_message}')
            return {
                'success': False,
                'message': f'Failed to fetch logs for resource: {error_message}',
                'error': str(e),
            }
